"""
Test search cache stampede protection functionality.
"""

import concurrent.futures
import threading
import time
import unittest
from tempfile import TemporaryDirectory
from typing import List, Dict, Any

from src.local_deep_research.utilities.search_cache import SearchCache


class TestSearchCacheStampede(unittest.TestCase):
    """Test that SearchCache properly prevents cache stampedes."""

    def setUp(self):
        """Set up test cache with temporary directory."""
        self.temp_dir = TemporaryDirectory()
        self.cache = SearchCache(
            cache_dir=self.temp_dir.name, max_memory_items=100, default_ttl=60
        )

        # Track API calls
        self.api_call_count = 0
        self.api_call_lock = threading.Lock()
        self.api_call_threads = []

    def tearDown(self):
        """Clean up temporary directory."""
        self.temp_dir.cleanup()

    def mock_api_call(self, query: str) -> List[Dict[str, Any]]:
        """Mock API call that simulates slow external search."""
        with self.api_call_lock:
            self.api_call_count += 1
            self.api_call_threads.append(threading.current_thread().name)

        # Simulate network latency
        time.sleep(0.1)

        # Return mock results
        return [
            {"title": f"Result 1 for {query}", "url": "http://example.com/1"},
            {"title": f"Result 2 for {query}", "url": "http://example.com/2"},
        ]

    def test_single_thread_caching(self):
        """Test basic caching works for a single thread."""
        query = "test query single"

        # First call - should miss cache and call API
        results = self.cache.get_or_fetch(
            query, lambda: self.mock_api_call(query)
        )

        self.assertEqual(self.api_call_count, 1)
        self.assertIsNotNone(results)
        self.assertEqual(len(results), 2)

        # Second call - should hit cache
        results2 = self.cache.get_or_fetch(
            query, lambda: self.mock_api_call(query)
        )

        self.assertEqual(self.api_call_count, 1)  # No additional API call
        self.assertEqual(results, results2)

    def test_stampede_protection(self):
        """Test that multiple concurrent requests only trigger one API call."""
        query = "test query stampede"
        num_threads = 10
        results_list = []
        barrier = threading.Barrier(num_threads)

        def fetch_with_cache(thread_id):
            """Function each thread will run."""
            barrier.wait()  # Ensure all threads start simultaneously
            results = self.cache.get_or_fetch(
                query, lambda: self.mock_api_call(query)
            )
            results_list.append((thread_id, results))
            return results

        # Launch multiple threads concurrently
        with concurrent.futures.ThreadPoolExecutor(
            max_workers=num_threads
        ) as executor:
            futures = [
                executor.submit(fetch_with_cache, i) for i in range(num_threads)
            ]
            concurrent.futures.wait(futures)

        # Check that only ONE API call was made despite multiple threads
        self.assertEqual(
            self.api_call_count,
            1,
            f"Expected 1 API call but got {self.api_call_count}. "
            f"Threads that called API: {self.api_call_threads}",
        )

        # Check all threads got the same results
        self.assertEqual(len(results_list), num_threads)
        first_result = results_list[0][1]
        for thread_id, result in results_list:
            self.assertEqual(
                result,
                first_result,
                f"Thread {thread_id} got different results",
            )

    def test_different_queries_not_blocked(self):
        """Test that different queries can be fetched concurrently."""
        queries = [f"query_{i}" for i in range(5)]
        results_dict = {}

        def fetch_query(query):
            results = self.cache.get_or_fetch(
                query, lambda: self.mock_api_call(query)
            )
            results_dict[query] = results
            return results

        # Launch threads for different queries
        with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
            futures = [executor.submit(fetch_query, query) for query in queries]
            concurrent.futures.wait(futures)

        # Each unique query should trigger exactly one API call
        self.assertEqual(
            self.api_call_count,
            5,
            f"Expected 5 API calls for 5 different queries, got {self.api_call_count}",
        )

        # Verify each query got cached
        for query in queries:
            cached = self.cache.get(query)
            self.assertIsNotNone(cached)
            self.assertEqual(cached, results_dict[query])

    def test_expired_cache_refetch(self):
        """Test that expired cache entries trigger refetch with stampede protection."""
        query = "test expiring query"

        # First fetch with very short TTL
        self.cache.get_or_fetch(
            query,
            lambda: self.mock_api_call(query),
            ttl=1,  # 1 second TTL
        )

        self.assertEqual(self.api_call_count, 1)

        # Wait for cache to expire
        time.sleep(1.5)

        # Multiple threads try to fetch expired entry
        num_threads = 5
        results_list = []
        barrier = threading.Barrier(num_threads)

        def fetch_expired():
            barrier.wait()  # Ensure all threads start simultaneously
            return self.cache.get_or_fetch(
                query, lambda: self.mock_api_call(query)
            )

        with concurrent.futures.ThreadPoolExecutor(
            max_workers=num_threads
        ) as executor:
            futures = [
                executor.submit(fetch_expired) for _ in range(num_threads)
            ]
            for future in concurrent.futures.as_completed(futures):
                results_list.append(future.result())

        # Should have made exactly 2 API calls total (initial + refetch after expiry)
        self.assertEqual(
            self.api_call_count,
            2,
            f"Expected 2 API calls (initial + refetch), got {self.api_call_count}",
        )

        # All threads should get the same refreshed results
        for results in results_list:
            self.assertEqual(results, results_list[0])

    def test_get_during_fetch(self):
        """Test that get() returns None while another thread is fetching."""
        query = "test get during fetch"
        fetch_started = threading.Event()
        fetch_complete = threading.Event()

        def slow_fetch():
            fetch_started.set()
            time.sleep(0.5)  # Simulate slow API
            fetch_complete.set()
            return self.mock_api_call(query)

        def thread1_fetch():
            """First thread - initiates fetch."""
            return self.cache.get_or_fetch(query, slow_fetch)

        def thread2_get():
            """Second thread - checks cache during fetch."""
            fetch_started.wait()  # Ensure fetch has started
            result = self.cache.get(query)
            return result

        def thread3_get_or_fetch():
            """Third thread - uses get_or_fetch during fetch."""
            fetch_started.wait()  # Ensure fetch has started
            time.sleep(0.1)  # Small delay
            result = self.cache.get_or_fetch(
                query, lambda: self.mock_api_call(query)
            )
            return result

        with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
            future1 = executor.submit(thread1_fetch)
            future2 = executor.submit(thread2_get)
            future3 = executor.submit(thread3_get_or_fetch)

            # Thread 2 should return None (cache miss during fetch)
            result2 = future2.result()
            self.assertIsNone(result2, "get() should return None during fetch")

            # Thread 3 should wait and get results without calling API again
            result3 = future3.result()
            self.assertIsNotNone(
                result3, "get_or_fetch() should wait and get results"
            )

            # Thread 1 completes the fetch
            result1 = future1.result()
            self.assertIsNotNone(result1)
            self.assertEqual(result1, result3)

            # Only one API call should have been made
            self.assertEqual(self.api_call_count, 1)

    def test_fetch_error_handling(self):
        """Test that fetch errors don't break stampede protection."""
        query = "test error query"
        attempt_count = 0

        def failing_fetch():
            nonlocal attempt_count
            attempt_count += 1
            if attempt_count < 2:
                raise Exception("API error")
            return self.mock_api_call(query)

        # First attempt fails
        result1 = self.cache.get_or_fetch(query, failing_fetch)
        self.assertIsNone(result1)
        self.assertEqual(attempt_count, 1)

        # Second attempt succeeds
        result2 = self.cache.get_or_fetch(query, failing_fetch)
        self.assertIsNotNone(result2)
        self.assertEqual(attempt_count, 2)

        # Third attempt uses cache
        result3 = self.cache.get_or_fetch(query, failing_fetch)
        self.assertEqual(result2, result3)
        self.assertEqual(attempt_count, 2)  # No additional attempt

    def test_cleanup_of_locks(self):
        """Test that locks are cleaned up after fetching."""
        query = "test cleanup"

        # Fetch and cache
        self.cache.get_or_fetch(query, lambda: self.mock_api_call(query))

        # Wait for cleanup thread to run (cleanup delay is 2 seconds)
        time.sleep(3)

        # Check that locks and events have been cleaned up
        query_hash = self.cache._get_query_hash(query)
        self.assertNotIn(query_hash, self.cache._fetch_locks)
        self.assertNotIn(query_hash, self.cache._fetch_events)
        self.assertNotIn(query_hash, self.cache._fetch_results)


if __name__ == "__main__":
    unittest.main()
